# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import builtins
import os
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from olive.model import ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.qairt.mha2sha import QairtMHA2SHA
from olive.passes.pass_config import PassConfigParam
from test.utils import get_onnx_model


# Mock OnnxModel for external QAIRT SDK
class MockOnnxModelInstance:
    """A mock instance for qti.aisw.tools.core.utilities.framework.onnx.OnnxModel.

    Each SDK method forwards its arguments to a ``MagicMock`` so tests can assert
    on how the pass drove the SDK:

    * ``mock_mha2sha_v2`` records calls to ``mha2sha_v2`` (keyword arguments only).
    * ``mock_mha2sha`` records calls to ``mha2sha`` (keyword arguments only).
    * ``mock_export`` (same object as ``_export_mock_internal``) records calls to
      ``export`` positionally as ``(output_path, prefix)``.

    Set ``has_mha2sha_v2 = False`` to simulate an SDK without ``mha2sha_v2``: the
    attribute then behaves as missing, so both ``hasattr`` checks and
    ``try/except AttributeError`` around the call observe its absence.
    """

    def __init__(self, model_path):
        # Path this instance was "loaded" from; kept so tests can inspect it.
        self.model_path = model_path
        self.mock_mha2sha_v2 = MagicMock()
        self.mock_mha2sha = MagicMock()
        self._export_mock_internal = MagicMock()
        # Public alias named consistently with mock_mha2sha / mock_mha2sha_v2.
        self.mock_export = self._export_mock_internal
        self.has_mha2sha_v2 = True  # Default to having v2 for most tests

    @property
    def mha2sha_v2(self):
        """Return the v2 entry point, or raise AttributeError when v2 is disabled."""
        if not self.has_mha2sha_v2:
            # Raising at attribute lookup makes the method look absent to
            # hasattr/getattr as well as to a try/except around the call.
            raise AttributeError("mha2sha_v2 not available")

        def _mha2sha_v2(**kwargs):
            self.mock_mha2sha_v2(**kwargs)

        return _mha2sha_v2

    def mha2sha(self, **kwargs):
        """Record a call to the v1 MHA-to-SHA transformation."""
        self.mock_mha2sha(**kwargs)

    def export(self, output_path, prefix):
        """Record the call, then write a placeholder ``<prefix>.onnx`` into ``output_path``.

        The directory is created if needed so the pass can build a model handler
        pointing at a file that actually exists.
        """
        # Call the internal mock for tracking purposes
        self._export_mock_internal(output_path, prefix)

        output_dir = Path(output_path)
        output_dir.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
        output_file_path = output_dir / f"{prefix}.onnx"
        output_file_path.write_text(
            f"This is a dummy ONNX file for {prefix}.onnx\n"
            f"Generated by MockOnnxModelInstance at {output_file_path}\n",
            encoding="utf-8",
        )


@pytest.fixture(name="qairt_pass_instance")
def qairt_pass_instance_fixture():
    """Build a QairtMHA2SHA pass from an empty config with search disabled."""
    empty_config = {}
    return create_pass_from_dict(QairtMHA2SHA, empty_config, disable_search=True)


@pytest.fixture(name="mock_accelerator_spec")
def mock_accelerator_spec_fixture():
    """Return a bare MagicMock standing in for an AcceleratorSpec."""
    accelerator_spec = MagicMock()
    return accelerator_spec


@pytest.fixture(name="tmp_output_dir")
def tmp_output_dir_fixture(tmp_path):
    """Expose pytest's per-test ``tmp_path`` directory as a plain string."""
    output_dir = os.fspath(tmp_path)
    return output_dir


@pytest.fixture(name="mock_qairt_sdk_classes")
def mock_qairt_sdk_classes_fixture():
    """Mock both import paths of the QAIRT SDK ``OnnxModel`` class.

    ``OnnxModel`` is registered in ``sys.modules`` under both
    ``qti.aisw.tools.core.utilities.framework.frameworks.onnx`` (new layout) and
    ``qti.aisw.tools.core.utilities.framework.onnx`` (old layout) for the duration
    of the test. On each path, ``OnnxModel.load`` is a ``MagicMock`` whose side
    effect builds a fresh ``MockOnnxModelInstance`` per call.

    Yields the mock class for the new path. Tests may wrap or replace its
    ``load.side_effect`` to customize or capture the loaded instances.
    """

    # Create the mock instance provider function
    def _create_mock_onnx_model_instance(model_path):
        return MockOnnxModelInstance(model_path)

    # One mocked OnnxModel class per import path; .load() returns our custom mock instance
    mock_onnx_model_class_new = MagicMock()
    mock_onnx_model_class_new.load.side_effect = _create_mock_onnx_model_instance
    mock_onnx_model_class_old = MagicMock()
    mock_onnx_model_class_old.load.side_effect = _create_mock_onnx_model_instance

    mocked_modules = {
        "qti.aisw.tools.core.utilities.framework.frameworks.onnx": MagicMock(OnnxModel=mock_onnx_model_class_new),
        "qti.aisw.tools.core.utilities.framework.onnx": MagicMock(OnnxModel=mock_onnx_model_class_old),
    }

    # A single context-managed patch restores sys.modules on teardown, and also
    # if anything raises while the fixture is active.
    with patch.dict("sys.modules", mocked_modules):
        # Yield the 'new'-path mock; the pass presumably tries that path first (confirm against the pass).
        yield mock_onnx_model_class_new


def test_mha2sha_default_config(mock_accelerator_spec):
    """Check that the default config exposes an optional ``mha2sha_kwargs`` parameter."""
    default_config = QairtMHA2SHA._default_config(mock_accelerator_spec)  # pylint: disable=protected-access
    assert "mha2sha_kwargs" in default_config

    kwargs_param = default_config["mha2sha_kwargs"]
    assert isinstance(kwargs_param, PassConfigParam)
    assert kwargs_param.default_value is None


def test_mha2sha_for_onnx_model_handler(qairt_pass_instance, tmp_output_dir, mock_qairt_sdk_classes):
    """Test run with a single ONNXModelHandler."""
    input_model = get_onnx_model()

    # ``load.call_args`` only records the call arguments, not what the side effect
    # returned (``call_args.return_value`` is just another ``_Call`` object, so any
    # assertion chained off it is a silent no-op). Wrap the fixture's side effect to
    # capture each MockOnnxModelInstance the pass actually loads.
    loaded_instances = []
    create_instance = mock_qairt_sdk_classes.load.side_effect

    def _capture_instance(model_path):
        instance = create_instance(model_path)
        loaded_instances.append(instance)
        return instance

    mock_qairt_sdk_classes.load.side_effect = _capture_instance

    transformed_model = qairt_pass_instance.run(input_model, tmp_output_dir)

    # Assertions
    assert isinstance(transformed_model, ONNXModelHandler)
    assert os.path.dirname(transformed_model.model_path) == tmp_output_dir
    assert transformed_model.onnx_file_name == input_model.onnx_file_name

    # Verify QAIRT SDK calls
    mock_qairt_sdk_classes.load.assert_called_once_with(model_path=input_model.model_path)
    assert len(loaded_instances) == 1
    loaded_qairt_instance = loaded_instances[0]
    loaded_qairt_instance.mock_mha2sha_v2.assert_called_once_with()  # No kwargs passed by default
    loaded_qairt_instance.mock_mha2sha.assert_not_called()  # Ensure V1 is not called

    # The mock records export calls positionally as (output_path, prefix).
    export_mock = loaded_qairt_instance._export_mock_internal  # pylint: disable=protected-access
    export_mock.assert_called_once()
    export_dir, export_prefix = export_mock.call_args.args
    assert str(export_dir) == tmp_output_dir
    assert export_prefix == input_model.onnx_file_name


def test_mha2sha_v1_fallback(qairt_pass_instance, tmp_output_dir, mock_qairt_sdk_classes):
    """Test that the pass falls back to mha2sha (v1) if v2 is not available."""
    dummy_model = get_onnx_model()

    # Capture loaded instances here: ``load.call_args`` does not expose the
    # side effect's return value, so assertions on it would be silent no-ops.
    loaded_instances = []
    original_side_effect_func = mock_qairt_sdk_classes.load.side_effect

    def custom_side_effect_func(model_path):
        instance = original_side_effect_func(model_path)
        instance.has_mha2sha_v2 = False  # Simulate an SDK without mha2sha_v2
        loaded_instances.append(instance)
        return instance

    mock_qairt_sdk_classes.load.side_effect = custom_side_effect_func

    _ = qairt_pass_instance.run(dummy_model, output_model_path=tmp_output_dir)

    # Verify V1 was called and V2 was not
    assert len(loaded_instances) == 1
    loaded_qairt_instance = loaded_instances[0]
    loaded_qairt_instance.mock_mha2sha.assert_called_once_with()
    loaded_qairt_instance.mock_mha2sha_v2.assert_not_called()


def test_mha2sha_kwargs_passed(tmp_output_dir, mock_qairt_sdk_classes):
    """Test that additional kwargs are passed to mha2sha_v2/mha2sha."""
    dummy_model = get_onnx_model()
    mha2sha_pass = create_pass_from_dict(
        QairtMHA2SHA, {"mha2sha_kwargs": {"param1": "value1", "param2": 123}}, disable_search=True
    )

    # Every load() creates a fresh MockOnnxModelInstance. ``load.call_args`` does not
    # expose it, so capture instances here to make the assertions below real.
    loaded_instances = []
    create_instance = mock_qairt_sdk_classes.load.side_effect

    def _capturing_side_effect(has_mha2sha_v2):
        """Build a load() side effect that sets v2 availability and records the instance."""

        def _side_effect(model_path):
            instance = create_instance(model_path)
            instance.has_mha2sha_v2 = has_mha2sha_v2
            loaded_instances.append(instance)
            return instance

        return _side_effect

    # Test with v2 available
    mock_qairt_sdk_classes.load.side_effect = _capturing_side_effect(has_mha2sha_v2=True)
    mha2sha_pass.run(dummy_model, output_model_path=tmp_output_dir)
    assert len(loaded_instances) == 1
    loaded_qairt_instance_v2 = loaded_instances[0]
    loaded_qairt_instance_v2.mock_mha2sha_v2.assert_called_once_with(param1="value1", param2=123)
    loaded_qairt_instance_v2.mock_mha2sha.assert_not_called()

    # Test with v1 fallback: the next load() yields an instance without mha2sha_v2
    mock_qairt_sdk_classes.load.side_effect = _capturing_side_effect(has_mha2sha_v2=False)
    mha2sha_pass.run(dummy_model, output_model_path=tmp_output_dir)
    assert len(loaded_instances) == 2
    loaded_qairt_instance_v1 = loaded_instances[1]
    loaded_qairt_instance_v1.mock_mha2sha.assert_called_once_with(param1="value1", param2=123)
    loaded_qairt_instance_v1.mock_mha2sha_v2.assert_not_called()


def test_import_error_qti_aisw(qairt_pass_instance, tmp_output_dir):
    """Test that ImportError is raised if qti.aisw.tools cannot be imported.

    This test intentionally does not use the `mock_qairt_sdk_classes` fixture.
    Instead of touching sys.modules, it patches ``builtins.__import__`` so that
    ``import`` statements targeting either QAIRT ``OnnxModel`` module path raise
    ImportError; all other imports are delegated to the original ``__import__``.
    """
    blocked_modules = frozenset(
        {
            "qti.aisw.tools.core.utilities.framework.frameworks.onnx",
            "qti.aisw.tools.core.utilities.framework.onnx",
        }
    )
    # Bind the real __import__ before patching so the side effect can delegate to it.
    original_import = builtins.__import__

    def import_side_effect(name, *args, **kwargs):
        if name in blocked_modules:
            raise ImportError("Mock import error")
        return original_import(name, *args, **kwargs)

    # Build the input model before patching so only the pass runs under the patch.
    dummy_model = get_onnx_model()
    with patch("builtins.__import__", side_effect=import_side_effect), pytest.raises(ImportError):
        qairt_pass_instance.run(dummy_model, output_model_path=tmp_output_dir)
